Motif using logos in python using matplotlib

This is a proof of concept for plotting motif logos using matplotlib. The coordinates for bases were picked of from seqLogo package.

seqLogo is LGPL(v2.1) licensed and hence so is this notebook.

See LICENCE


In [20]:
%matplotlib inline
from __future__ import division
from Bio import motifs
import seaborn
import matplotlib.pyplot as plt
plt.style.use('seaborn-ticks')
import numpy as np
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
from matplotlib import transforms

 

colors_scheme = {'G': 'orange', 'A': 'red', 'C': 'blue', 'T': 'darkgreen'}

In [13]:
with open("./meme_out/meme.txt") as handle:
    m = motifs.parse(handle, "meme")

In [14]:
def plotA(xstart=0, ystart=0, xscale=1, yscale=1):
    """Plot A"""
    x = np.array([ 0.  ,  0.4 ,  0.6 ,  1.  ,  0.8 ,  0.68,  0.32,  0.2 ,  0.])
    y = np.array([ 0.  ,  1.  ,  1.  ,  0.  ,  0.  ,  0.3 ,  0.3 ,  0.  ,  0.])
    x = xstart + xscale*x
    y = ystart + yscale*y
    triy =  np.array([0.4 ,   0.4 , 0.75, 0.4])
    trix =  np.array([0.36,   0.64, 0.5, 0.36])

    trix = xstart + xscale*trix
    triy = ystart + yscale*triy
    
    return x, y, trix, triy

def plotT(xstart=0, ystart=0, xscale=1, yscale=1):
    """Plot T"""
    x = np.array([ 0. ,  1. ,  1. ,  0.6,  0.6,  0.4,  0.4,  0. ])
    y = np.array([ 1. ,  1. ,  0.9,  0.9,  0. ,  0. ,  0.9,  0.9])
    x = xstart + xscale*x
    y = ystart + yscale*y
    return x,y

def plotG(xstart=0, ystart=0, xscale=1, yscale=1):
    """Plot G"""
    angle1 = np.linspace(0.3+np.pi/2,np.pi,num=100)
    angle2 = np.linspace(np.pi,1.5*np.pi, num=100)
  
    xl1 = 0.5 + 0.5*np.sin(angle1)
    yl1 =0.5 + 0.5*np.cos(angle1)
    xl2 =0.5 + 0.5*np.sin(angle2)
    yl2 =0.5 + 0.5*np.cos(angle2)

    xl = np.concatenate((xl1, xl2))
    yl = np.concatenate((yl1, yl2))

    x = np.concatenate((xl, xl[::-1]))
    y = np.concatenate((yl, 1-yl[::-1]))

    xi1 = 0.5 + 0.35*np.sin(angle1)
    yi1 = 0.5 + 0.35*np.cos(angle1)
    xi1 = xi1[np.where(yi1<=np.max(yl1))]
    yi1 = yi1[np.where(yi1<=np.max(yl1))]
    yi1[0] = np.max(yl1)

    xi2 = 0.5 + 0.35*np.sin(angle2)
    yi2 = 0.5 + 0.35*np.cos(angle2)

    xi = np.concatenate((xi1,xi2))
    yi = np.concatenate((yi1,yi2))

    x1 = np.concatenate((xi,xi[::-1]))
    y1 = np.concatenate((yi,1-yi[::-1]))

    x = np.concatenate((x,x1[::-1]))
    y = np.concatenate((y, y1[::-1]))
    
    r1 = np.max(xl1)

    h1 = 0.4
    xadd = np.array([r1,0.5,0.5,r1-0.2,r1-0.2,r1,r1])
    yadd = np.array([h1,h1,h1-0.1,h1-0.1,0,0,h1])
    
    x = np.concatenate((x[::-1],xadd))
    y = np.concatenate((y[::-1],yadd))

    x = xstart + xscale*x
    y = ystart + yscale*y
    
    return x,y

def plotC(xstart=0, ystart=0, xscale=1, yscale=1):
    """Plot C"""
    
    angle1 = np.linspace(0.3+np.pi/2,np.pi,num=100)
    angle2 = np.linspace(np.pi,1.5*np.pi, num=100)
  
    xl1 = 0.5 + 0.5*np.sin(angle1)
    yl1 = 0.5 + 0.5*np.cos(angle1)
    xl2 = 0.5 + 0.5*np.sin(angle2)
    yl2 = 0.5 + 0.5*np.cos(angle2)

    xl = np.concatenate((xl1, xl2))
    yl = np.concatenate((yl1, yl2))

    x = np.concatenate((xl, xl[::-1]))
    y = np.concatenate((yl, 1-yl[::-1]))

    xi1 = 0.5 + 0.35*np.sin(angle1)
    yi1 = 0.5 + 0.35*np.cos(angle1)
    xi1 = xi1[np.where(yi1<=np.max(yl1))]
    yi1 = yi1[np.where(yi1<=np.max(yl1))]
    yi1[0] = np.max(yl1)

    xi2 = 0.5 + 0.35*np.sin(angle2)
    yi2 = 0.5 + 0.35*np.cos(angle2)

    xi = np.concatenate((xi1,xi2))
    yi = np.concatenate((yi1,yi2))

    x1 = np.concatenate((xi,xi[::-1]))
    y1 = np.concatenate((yi,1-yi[::-1]))

    x = np.concatenate((x,x1[::-1]))
    y = np.concatenate((y, y1[::-1]))
    x = xstart + xscale*x
    y = ystart + yscale*y
    return x,y

In [15]:
def approximate_error(motif):
    """Calculate approximate error"""
    pwm = motif.pwm
    bases = list(pwm.keys())
    n = sum(motif.counts[bases[0]])
    approx_error = (len(bases)-1)/(2 * np.log(2) * n)
    return approx_error


def exact_error(motif):
    """Calculate exact error, using multinomial(na,nc,ng,nt)"""
    ## Super Slow. O(n^3)
    pwm = motif.pwm
    bases = list(pwm.keys())
    na = sum(motif.counts['A'])
    n = na
    nc = 0
    ng = 0
    nt = 0
    done = False
    exact_error = 0
    while not done:
        print (na,nc,ng,nt)
        exact_error += sum([-p*np.log2(p) for p in [na/n, nc/n, ng/n, nt/n]])
        if nt<=0:
            ## iterate inner loop            
            if ng > 0:
                ## g => t
                ng = ng - 1
                nt = nt + 1
            elif nc > 0:
                ## c -> g 
                nc = nc - 1;
                ng = ng + 1;
            else:
                ## a->c
                na = na - 1
                nc = nc + 1
        else:
            if ng > 0:
                ## g => t
                ng = ng - 1 
                nt = nt + 1
            elif nc>0:
                ## c => g; all t -> g
                nc = nc - 1
                ng = nt + 1
                nt = 0
            elif na>0:
                ## a => c; all g,t -> c
                nc = nt + 1
                na = na - 1
                nt = 0
            else:
                done = True
    return exact_correction

def calc_info_matrix(motif, correction_type='approx'):
    """Calculate information matrix with small sample correction"""
    pwm = motif.pwm
    bases = list(pwm.keys())
    if correction_type=='approx':
        error = approximate_error(motif)
    else:
        error = exact_error(motif)
    info_matrix = [2-error+sum([pwm[b][l]*np.nan_to_num(np.log2(pwm[b][l])) for b in bases]) for l in range(0, len(motif))]
    return info_matrix

def calc_relative_information(motif, correction_type='approx'):
    """Calculate relative information matrix"""
    pwm = motif.pwm
    bases = list(pwm.keys())
    if correction_type=='approx':
        info_matrix = calc_info_matrix(motif)
    else:
        info_matrix = calc_info_matrix(motif, 'exact')
    relative_info = {base: [prob*info for prob,info in zip(pwm[base], info_matrix)]  for base in bases}
    return relative_info

In [16]:
motif = m[0]
rel_info = calc_relative_information(motif)
bases = ['A', 'T', 'G', 'C']
fig = plt.figure()
ax = fig.add_subplot(111)
patches = []
white_spots = []
colors = []
xshift = 0
for i in range(0, len(motif)):
    scores = [(b,rel_info[b][i]) for b in bases]
    scores.sort(key=lambda t: t[1])
    yshift = 0
    for base, score in scores:
        if base=='A':
            X, Y, triX, triY = plotA(xstart=xshift, ystart=yshift, yscale=score) 
            white_spots.append((triX, triY))
            colors.append(colors_scheme['A'])
        elif base=='C':
            X, Y = plotC(xstart=xshift, ystart=yshift, yscale=score) 
            colors.append(colors_scheme['C'])
        elif base=='T':
            X, Y = plotT(xstart=xshift, ystart=yshift, yscale=score) 
            colors.append(colors_scheme['T'])
        elif base=='G':
            X, Y = plotG(xstart=xshift, ystart=yshift, yscale=score) 
            colors.append(colors_scheme['G'])
        yshift +=score
        plot = mpatches.Polygon(zip(X,Y), closed=True, fill=None)
        patches.append(plot)
    xshift+=1
collection = PatchCollection(patches, edgecolors='none', facecolors=colors)
ax.add_collection(collection)
for triX, triY in white_spots:
    ax.fill_between(triX, triY, color="white", linewidth=0.0)

ax.set_xticks(np.arange(0.5,len(motif)+0.5,1))
ax.set_xticklabels(range(1, len(motif)+1))
seaborn.despine(ax=ax, offset=10, trim=True)
plt.show()


/home/saket/anaconda2/envs/clipseq/lib/python3.5/site-packages/ipykernel/__main__.py:66: RuntimeWarning: divide by zero encountered in log2
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-16-e2f433f4e563> in <module>()
     27             colors.append(colors_scheme['G'])
     28         yshift +=score
---> 29         plot = mpatches.Polygon(zip(X,Y), closed=True, fill=None)
     30         patches.append(plot)
     31     xshift+=1

/home/saket/anaconda2/envs/clipseq/lib/python3.5/site-packages/matplotlib/patches.py in __init__(self, xy, closed, **kwargs)
    898         Patch.__init__(self, **kwargs)
    899         self._closed = closed
--> 900         self.set_xy(xy)
    901 
    902     def get_path(self):

/home/saket/anaconda2/envs/clipseq/lib/python3.5/site-packages/matplotlib/patches.py in set_xy(self, xy)
    962         xy = np.asarray(xy)
    963         if self._closed:
--> 964             if len(xy) and (xy[0] != xy[-1]).any():
    965                 xy = np.concatenate([xy, [xy[0]]])
    966         else:

TypeError: len() of unsized object

MEME Logo:

A more scalable approach

Credits: Markus Piotrowski See: https://github.com/biopython/biopython/issues/850#issuecomment-225708297


In [10]:
## Credits Markus Piotrowski
## See: https://github.com/biopython/biopython/issues/850#issuecomment-225708297

import matplotlib.patheffects
import matplotlib.pyplot as plt
import matplotlib.transforms as mtrans
from matplotlib import transforms

class Scale(matplotlib.patheffects.RendererBase):
    def __init__(self, sx, sy=None):
        self._sx = sx
        self._sy = sy

    def draw_path(self, renderer, gc, tpath, affine, rgbFace):
        affine = affine.identity().scale(self._sx, self._sy)+affine
        renderer.draw_path(gc, tpath, affine, rgbFace)

In [87]:
fig = plt.figure()
#ax.set_xticks(range(0.1,0.1*len(motif),0.1))
#ax.set_xticklabels(range(1,len(motif)+1))

fig.set_size_inches(len(motif),2.5)
ax = fig.add_subplot(111)

xshift = 0
trans_offset = transforms.offset_copy(ax.transAxes, 
                                  fig=fig, 
                                  x=0, 
                                  y=0, 
                                  units='points')
for i in range(0, len(motif)):
    scores = [(b,rel_info[b][i]) for b in bases]
    scores.sort(key=lambda t: t[1])
    yshift = 0
    for base, score in scores:
        txt = ax.text(0, 
                      0, 
                      base, 
                      transform=trans_offset,
                      fontsize=80, 
                      color=colors_scheme[base],
                      weight='bold',
                      ha='center',
                      family='sans-serif'
                      )
        txt.set_path_effects([Scale(1.0, score)])
        fig.canvas.draw()
        window_ext = txt.get_window_extent(txt._renderer)
        yshift = window_ext.height*score
        trans_offset = transforms.offset_copy(txt._transform, fig=fig, y=yshift, units='points')
    xshift+=window_ext.width
    trans_offset = mtrans.offset_copy(ax.transAxes, fig=fig, x=xshift, units='points')

ax.set_yticks(range(0,3))
#ax.set_xticks(np.arange(-window_ext.width/100,len(motif)*window_ext.width/100+0.5,window_ext.width/100))

ax.set_xticks(range(len(motif)))
ax.set_xticklabels(range(1,len(motif)+1))
#ax.set_xticklabels(ax.transData.transform([(i,0) for i in range(len(motif))]))
ax.set_yticklabels(np.arange(0,3,1))
#seaborn.despine(ax=ax, offset=30, trim=True)

plt.show()


MEME Logo:


In [72]:
ax.transData.transform((5, 0))
ax.transData.transform([(i,0) for i in range(len(motif))])


Out[72]:
array([[  190.8       ,    22.5       ],
       [  287.42337662,    22.5       ],
       [  384.04675325,    22.5       ],
       [  480.67012987,    22.5       ],
       [  577.29350649,    22.5       ],
       [  673.91688312,    22.5       ],
       [  770.54025974,    22.5       ],
       [  867.16363636,    22.5       ],
       [  963.78701299,    22.5       ],
       [ 1060.41038961,    22.5       ],
       [ 1157.03376623,    22.5       ],
       [ 1253.65714286,    22.5       ],
       [ 1350.28051948,    22.5       ],
       [ 1446.9038961 ,    22.5       ],
       [ 1543.52727273,    22.5       ]])

In [28]:
from matplotlib.transforms import offset_copy

In [30]:
print(offset_copy.__doc__)


    Return a new transform with an added offset.
      args:
        trans is any transform
      kwargs:
        fig is the current figure; it can be None if units are 'dots'
        x, y give the offset
        units is 'inches', 'points' or 'dots'
    

In [ ]: